#!/usr/bin/env python
# -*- coding:utf-8 -*-


'''
@Author : shouke
'''

import gevent
import logging
import json
import re
import random
from copy import deepcopy
from locust import HttpUser, task, between, constant_pacing,constant
from locust.env import Environment
from locust.stats import stats_printer, stats_history
from locust.log import setup_logging
from locust.contrib.fasthttp import FastHttpUser
from config.config import log_level, slave_bind_host, master_bind_port
from utils.utils import Utils


# Configure Locust logging at the configured level (logfile=None: no log file).
setup_logging(log_level, None)
logger = logging.getLogger(__name__)

# Maps the target names accepted in assertion/extraction step configs to the
# response attribute each one reads.
request_response_field_map = {
    'status': 'status',
    'status_code': 'status_code',
    'header': 'headers',
    'headers': 'headers',
    'body': 'text',
}

class User(FastHttpUser):
    '''Locust user that executes configurable request chains (scenarios).

    Each task iteration picks a chain id, runs the chain's steps through
    ``run_actions`` (which dispatches each step to the method of this class
    named by its ``action``) and increments the user's iteration counter.

    Per-user variables live in
    ``self.environment.runner.user_resource_dict[self.client_id]``. The runner
    attributes used here (``user_resource_dict``, ``user_share_resouce_dict``,
    ``file_data_dict``, ``target_scenario``, ``scenario_chain_id_list``,
    ``scenario_chain_weight_list``) are presumably set up by the project's
    runner - they are not defined in this class.
    '''

    # Remaining weights used by get_chain_id_for_task. Class-level on purpose:
    # every User instance in this process draws from the same pool.
    chain_weight_list = []

    # HTTP status codes the default assertion reports as success.
    default_success_status_codes = (200, 201, 301)

    def __init__(self, environment):
        super().__init__(environment)
        # Key of this user's entry in runner.user_resource_dict.
        # NOTE(review): id() values may be reused once an object is garbage
        # collected, and on_stop does not remove the entry, so a new user could
        # inherit a finished user's state - confirm this is acceptable.
        self.client_id = id(self.client)

    def get_chain_id_for_task(self):
        '''Return the id of the chain the next iteration should run.

        Chains come from ``runner.scenario_chain_id_list`` with the parallel
        ``runner.scenario_chain_weight_list``: each call takes the chain with
        the largest remaining weight and decrements that weight. When no
        positive weight remains, the pool is refilled from the configured
        weights, so over a full cycle each chain runs ``weight`` times.

        Raises:
            ValueError: if no chain weights are configured.
        '''
        runner = self.environment.runner
        # Refill whenever no positive weight is left. Refilling only when every
        # weight is exactly 0 left the pool stuck (returning None) as soon as a
        # weight went negative, e.g. when all configured weights are 0.
        if not User.chain_weight_list or max(User.chain_weight_list) <= 0:
            User.chain_weight_list = list(runner.scenario_chain_weight_list)
        weights = User.chain_weight_list
        index = weights.index(max(weights))
        weights[index] -= 1
        return runner.scenario_chain_id_list[index]

    def on_start(self):
        '''Create this user's resource entry the first time it starts.'''
        user_resource_dict = self.environment.runner.user_resource_dict
        if self.client_id not in user_resource_dict:  # freshly started user
            user_resource_dict[self.client_id] = {
                'iteration_count': 0,  # completed task iterations
                'func_exec_times': {},  # step id -> run count, for once-only controllers
            }

    def exec_once_only_controller(self, step):
        '''Run the children of ``step`` only the first time this user reaches it.

        Executions are tracked per user by ``step['id']``.
        '''
        exec_times = self.environment.runner.user_resource_dict[self.client_id]['func_exec_times']
        if step['id'] in exec_times:
            return
        exec_times[step['id']] = 1
        if step.get('children'):
            self.run_actions(step['children'])

    def send_request(self, step):
        ''' Send the HTTP request described by ``step`` and run its children.

        ``step`` is serialised to JSON, passed through
        ``Utils.replace_variable`` (presumably substituting this user's
        variables) and parsed back, so the shared step configuration is never
        modified. Request arguments come from the keys ``method``, ``path``,
        ``name``, ``data``, ``headers``, ``auth``, ``stream``, ``json`` and
        ``allow_redirects``. The response is attached to every child step under
        ``'response'`` before the children run.

        The request is sent with ``catch_response=True``, so it must be
        reported explicitly: the default status-code assertion runs whenever
        no ``assert_response`` child is configured.
        '''
        step = json.dumps(step, ensure_ascii=True, indent=2)
        step = Utils.replace_variable(step, self.environment, self.client_id)
        # json.loads() no longer accepts ``encoding`` (removed in Python 3.9).
        step = json.loads(step)
        logger.debug('正准备发送请求')
        logger.debug('【请求方法】：%s', step.get('method'))
        logger.debug('【请求地址】：%s', step.get('path'))
        logger.debug('【请求头】：%s', step.get('headers'))
        # Falls back to the form data when no JSON body is configured.
        logger.debug('【请求参数】：%s', step.get('json') or step.get('data'))
        logger.debug('【请求鉴权】：%s', step.get('auth'))
        logger.debug('【是否允许重定向】：%s', step.get('allow_redirects'))
        send = getattr(self.client, step['method'].lower())
        response = send(step.get('path'),
                        name=step.get('name'),
                        data=step.get('data'),
                        headers=step.get('headers'),
                        auth=step.get('auth'),
                        catch_response=True,
                        stream=step.get('stream'),
                        json=step.get('json'),
                        allow_redirects=step.get('allow_redirects'))
        self._log_response_debug(response)

        children = step.get('children')
        if not children:
            logger.debug('该请求未配置断言，正在执行默认断言')
            self.assert_response_default(response)
            return

        has_response_assertion = False
        for item in children:
            item['response'] = response
            # Compare case-insensitively, like run_actions' dispatch does.
            if (item.get('action') or '').lower() == 'assert_response':
                has_response_assertion = True

        # If there is no response assertion, run the default one: nothing else
        # would report this caught request and it would be missing from the stats.
        if not has_response_assertion:
            self.assert_response_default(response)
        self.run_actions(children)

    def _log_response_debug(self, response):
        '''Log the status, headers and body of ``response`` at DEBUG level.

        Returns immediately when DEBUG is disabled: pretty-printing headers and
        re-parsing the JSON body on every request is wasted work under load,
        and a malformed body must not abort the request just for a log line.
        '''
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug('【请求响应状态码】：%s', response.status)
        logger.debug('【请求响应头(未格式化)】：%s', response.headers)
        logger.debug('【请求响应头】：\n%s', json.dumps(response.headers, indent=2, default=str))
        try:
            content_type = str(response.headers['Content-Type']) if response.headers else ''
        except KeyError:  # response without a Content-Type header
            content_type = ''
        if content_type.lower().startswith('application/json'):
            logger.debug('【请求响应体(未格式化)】：%s', response.text)
            try:
                body = json.dumps(response.json(), indent=2, ensure_ascii=False)
            except (TypeError, ValueError):  # body is not valid JSON
                body = response.text
            logger.debug('【请求响应体】：\n%s', body)
        else:
            logger.debug('【请求响应体】：\n%s', response.text)

    def assert_contain(self, target, pattern_list, logic):
        ''' Return whether ``str(target)`` contains the expected patterns.

        ``logic`` (case-insensitive) 'and' requires every pattern, 'or' at
        least one; an empty ``pattern_list`` passes for 'and' and fails for
        'or'. Patterns go through ``str()`` so numeric patterns work too. Any
        other ``logic`` value is logged as a configuration error and the
        assertion fails.
        '''
        text = str(target)
        mode = logic.lower()
        if mode == 'and':
            return all(str(pattern) in text for pattern in pattern_list)
        if mode == 'or':
            return any(str(pattern) in text for pattern in pattern_list)
        logger.error('【请求断言结果】：断言逻辑(%s)配置错误，可选值为：and,or', logic)
        return False

    def assert_equal(self, target, pattern_list, logic):
        ''' Return whether ``target`` equals the expected patterns.

        ``logic`` (case-insensitive) 'and' requires equality with every
        pattern, 'or' with at least one; an empty ``pattern_list`` passes for
        'and' and fails for 'or'. Any other ``logic`` value is logged as a
        configuration error and the assertion fails.
        '''
        mode = logic.lower()
        if mode == 'and':
            return all(target == pattern for pattern in pattern_list)
        if mode == 'or':
            return any(target == pattern for pattern in pattern_list)
        logger.error('【请求断言结果】：断言逻辑(%s)配置错误，可选值为：and,or', logic)
        return False

    def assert_response_default(self, response):
        ''' Default assertion: pass when the status code is one of
        ``default_success_status_codes``, then report the request.
        '''
        # Compare the integer status_code; ``status`` may be a string on
        # FastHttp responses and would never match these ints.
        status_code = response.status_code
        # success()/failure() plus the explicit _report_*() call - presumably
        # because the response is never used in a ``with`` block that would
        # report it automatically.
        if status_code in self.default_success_status_codes:
            response.success()
            response._report_success()
            logger.debug('【请求断言结果】：成功')
        else:
            msg = '状态码(%s)错误' % status_code
            response.failure(msg)
            response._report_failure(msg)
            logger.debug('【请求断言结果】：失败，响应状态码：%s', status_code)

    def assert_response(self, step):
        ''' Response assertion: check one response field and report the request.

        ``step`` keys: ``response`` (attached by send_request), ``target`` (a
        key of request_response_field_map), ``rule`` (name of a method of this
        class such as ``assert_contain`` or ``assert_equal``), ``patterns`` and
        optional ``logic`` (defaults to 'and').

        NOTE(review): every assert_response step reports the request, so a
        request with several such children is counted several times in the
        statistics - confirm whether that is intended.
        '''
        response = step.get('response')
        logger.debug('正在执行请求断言')
        target = step['target'].lower()
        # Validate before looking the field up; otherwise a bad target raises
        # KeyError instead of producing this error message.
        if target not in request_response_field_map:
            msg = '【请求断言结果】：断言失败，断言目标对象配置错误，可选值为：%s' % ','.join(request_response_field_map)
            logger.error(msg)
            # The request was caught (catch_response=True) and must still be
            # reported, or it would be missing from the statistics.
            response.failure(msg)
            response._report_failure(msg)
            return

        actual = getattr(response, request_response_field_map[target])
        logger.debug('【请求断言对象】：%s', actual)
        logger.debug('【请求断言匹配规则】：%s', step['rule'])
        logger.debug('【请求断言模式】：%s', step['patterns'])
        rule = getattr(self, step['rule'])
        if rule(actual, step['patterns'], step.get('logic') or 'and'):
            response.success()
            response._report_success()
            logger.debug('【请求断言结果】：成功')
        else:
            response.failure(response.text)
            response._report_failure(response.text)
            logger.debug('【请求断言结果】：失败')

    def extract_by_regexp(self, step):
        '''Regular-expression extractor: store a match in a user variable.

        ``step`` keys: ``response``, ``target`` (a key of
        request_response_field_map), ``express`` (the pattern), ``refName``
        (variable to set), ``template`` and ``matchNo``.

        ``matchNo`` 0/empty picks a random match; N >= 1 picks the N-th match.
        For the chosen match, ``<refName>_g1``, ``<refName>_g2``, ... hold the
        capture groups (the whole match when the pattern has no group), and
        ``<refName>_g0`` holds their concatenation. ``$N$`` placeholders in the
        template are replaced by group N; an empty template, or one without
        placeholders, yields ``<refName>_g0``.
        '''
        logger.debug('正在执行正则提取')
        response = step.get('response')
        target = step['target'].lower()
        if target not in request_response_field_map:
            logger.error('【正则提取结果】：提取失败，提取源配置错误，可选值为：%s',
                         ','.join(request_response_field_map))
            return

        source = str(getattr(response, request_response_field_map[target]))
        re_match_list = re.findall(step.get('express'), source)
        variable = (step.get('refName') or '').strip()
        template = step.get('template') or ''
        match_no = step.get('matchNo')

        logger.debug('【正则提取表达式】：%s', step.get('express'))
        logger.debug('【正则提取源对象】：%s', source)
        logger.debug('【提取模板】：%s', template)
        logger.debug('【匹配数字】：%s', match_no)
        if not re_match_list:
            logger.warning('【正则提取结果】：匹配结果为空')
            return
        if not variable:
            logger.debug('【正则提取结果】：提取失败，未配置引用变量')
            return

        # matchNo may arrive from the config as a string such as "1".
        match_no = int(match_no) if match_no else 0
        if match_no == 0:
            match = random.choice(re_match_list)  # 0 / empty: random match
        elif 0 < match_no <= len(re_match_list):
            match = re_match_list[match_no - 1]
        else:
            logger.debug('【正则提取结果】：提取失败，匹配索引越界')
            return

        # re.findall yields plain strings for patterns with fewer than two
        # groups and tuples otherwise; normalise so a single-group match is not
        # split into one "group" per character.
        groups = match if isinstance(match, tuple) else (match,)
        group_values = {'0': ''.join(groups)}
        for group_index, group_value in enumerate(groups, start=1):
            group_values[str(group_index)] = group_value

        user_data = self.environment.runner.user_resource_dict[self.client_id]
        for group_index, group_value in group_values.items():
            user_data['%s_g%s' % (variable, group_index)] = group_value

        def substitute(placeholder):
            # '$1$', '$ 1 $' or '$1' -> '1'; unknown group numbers stay literal.
            index = placeholder.group(0).strip('$').strip()
            return group_values.get(index, placeholder.group(0))

        # One pass with re.subn: substituted values are never re-scanned, and
        # only groups of the current match are used.
        value, placeholder_count = re.subn(r'\$\s*\d+\s*\$?', substitute, template)
        if not placeholder_count:
            value = group_values['0']

        user_data[variable] = value
        logger.debug('【正则提取结果】：提取成功，提取值为：%s', value)

    def exec_counter(self, step):
        ''' Counter: set variable ``refName`` to the next counter value.

        The first execution yields ``value`` and each later one adds
        ``increment``. With ``independently_each_user`` false one counter, kept
        in ``runner.user_share_resouce_dict``, is shared by all users;
        otherwise each user keeps its own counter under ``'counter_values'`` in
        its resource dict.
        '''
        ref_name = step['refName']
        user_data = self.environment.runner.user_resource_dict[self.client_id]
        if step['independently_each_user']:
            # Per-user state must not live in ``step``: step configs are shared
            # by every user, so mutating them made this counter global.
            counters = user_data.setdefault('counter_values', {})
        else:
            counters = self.environment.runner.user_share_resouce_dict
        if ref_name not in counters:
            counters[ref_name] = step['value']
        else:
            counters[ref_name] += step['increment']
        user_data[ref_name] = counters[ref_name]

    def read_csv_file_data(self, step):
        '''CSV data set: load one more line of ``fileName`` and assign a row to
        this user's variables.

        Each call reads at most one new line and caches the parsed row in
        ``runner.file_data_dict[fileName]['data']``. When ``ignoreFirstLine``
        is set, the header is skipped first. The file is closed once its end
        is reached. Rows are plain ``delimiter``-split lines mapped onto
        ``variableNames``; quoting is not supported.

        Row selection:
          * shared (``independently_each_user`` false): with ``recyleOnEOF`` a
            file-wide counter cycles through the loaded rows; otherwise the
            most recently loaded row is used.
          * per user: the user's ``iteration_count`` indexes the rows, cycling
            with ``recyleOnEOF`` and sticking at the last row otherwise.
        '''
        file_name = step['fileName']
        file_encoding = step['fileEncoding']
        delimiter = step['delimiter']
        variable_names = step['variableNames'].split(delimiter)
        ignore_first_line = step['ignoreFirstLine']
        recycle_on_eof = step['recyleOnEOF']
        independently_each_user = step['independently_each_user']

        file_data_dict = self.environment.runner.file_data_dict
        if file_name not in file_data_dict:
            # Open before registering, so a failed open() leaves no half-built entry.
            file_handler = open(file_name, encoding=file_encoding, mode='r')
            file_data_dict[file_name] = {
                'file_handler': file_handler,
                'read_eof': False,  # whether the last line has been read
                'counter': 0,  # rows handed out in shared recycle mode
                'ignore_firstLine': ignore_first_line,  # header still to be skipped
                'data': [],  # rows parsed so far (variable name -> value)
            }
        file_data = file_data_dict[file_name]

        if not file_data['read_eof']:
            file_handler = file_data['file_handler']
            line = file_handler.readline()
            if file_data['ignore_firstLine']:  # skip the header line once
                line = file_handler.readline()
                file_data['ignore_firstLine'] = False
            if line.strip():
                row = {}
                for variable_name, variable_value in zip(variable_names, line.strip().split(delimiter)):
                    variable_name = variable_name.strip()
                    if variable_name:
                        row[variable_name] = variable_value.strip()
                file_data['data'].append(row)
            elif not line:
                # End of file: every row is cached, so release the handle.
                file_data['read_eof'] = True
                file_handler.close()

        data = file_data['data']
        if not data:
            # Empty file, or only blank lines so far: nothing to assign yet.
            # Indexing below would raise ZeroDivisionError/IndexError.
            logger.warning('【读取csv文件数据】：文件%s暂无可用数据', file_name)
            return

        user_data = self.environment.runner.user_resource_dict[self.client_id]
        if not independently_each_user:
            if recycle_on_eof:
                row = data[file_data['counter'] % len(data)]
                file_data['counter'] += 1
            else:
                row = data[-1]
        else:
            iteration_count = user_data['iteration_count']
            if recycle_on_eof:
                row = data[iteration_count % len(data)]
            else:
                row = data[min(len(data) - 1, iteration_count)]
        user_data.update(row)

    def run_actions(self, actions):
        ''' Execute ``actions`` in order.

        Each step's ``action`` (case-insensitive) names the method of this
        class that handles it (e.g. ``send_request``, ``extract_by_regexp``);
        that method is called with the step dict.
        '''
        for step in actions:
            logger.debug('【执行步骤】：%s', step['name'])
            logger.debug('【步骤配置】：%s', step)
            handler = getattr(self, step['action'].lower())
            handler(step)

    @task
    def locust_load_test_task(self):
        '''Run one iteration: pick a chain, execute its steps, bump the iteration count.'''
        runner = self.environment.runner
        chain = runner.target_scenario[self.get_chain_id_for_task()]['chain']
        self.run_actions(chain)
        runner.user_resource_dict[self.client_id]['iteration_count'] += 1

    def on_stop(self):
        '''No cleanup is performed when the user stops.'''
        pass

if __name__ == '__main__':
    # Build a Locust environment for the User class above and attach a worker
    # runner that connects to the master at the configured host/port.
    env = Environment(user_classes=[User])
    env.create_worker_runner(slave_bind_host, master_bind_port)

    # stats_printer(stats) gives back the printing callable; run it in its own
    # greenlet so the current statistics are printed periodically.
    print_stats_forever = stats_printer(env.stats)
    gevent.spawn(print_stats_forever)

    # Block until the runner's greenlet finishes.
    env.runner.greenlet.join()



